{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b8dbce5d-f3e3-4c82-92ee-0f64b83e51bb",
   "metadata": {},
   "source": [
    "# Infer-Retrieve-Rerank Llama Pack\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/run-llama/llama-hub/blob/main/llama_hub/llama_packs/research/infer_retrieve_rerank/infer_retrieve_rerank.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "This is our implementation of the paper [\"In-Context Learning for Extreme Multi-Label Classification](https://arxiv.org/pdf/2401.12178.pdf) by Oosterlinck et al.\n",
    "\n",
    "The paper proposes \"infer-retrieve-rerank\", a simple paradigm using frozen LLM/retriever models that can do \"extreme\"-label classification (the label space is huge).\n",
    "1. Given a user query, use an LLM to predict an initial set of labels.\n",
    "2. For each prediction, retrieve the actual label from the corpus.\n",
    "3. Given the final set of labels, rerank them using an LLM.\n",
    "\n",
    "All of these can be implemented as LlamaIndex abstractions. In this notebook we show you how to build \"infer-retrieve-rerank\" from scratch but also how to build it as a LlamaPack."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97d245ce-1f4a-45f8-a608-4941c8cb94b8",
   "metadata": {},
   "source": [
    "## Try out a Dataset\n",
    "\n",
    "We use the BioDEX dataset as mentioned in the paper.\n",
    "\n",
    "Here is the [link to the paper](https://arxiv.org/pdf/2305.13395.pdf). Here is the [link to the Github repo](https://github.com/KarelDO/BioDEX)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26bd6df6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-postprocessor-rankgpt-rerank\n",
    "%pip install llama-index-embeddings-openai\n",
    "%pip install llama-index-packs-infer-retrieve-rerank\n",
    "%pip install llama-index-llms-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e62fa4b-7322-4381-9703-c3c090b6eaf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jerryliu/Programming/llama-hub/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "\n",
    "# load the report-extraction dataset\n",
    "dataset = datasets.load_dataset(\"BioDEX/BioDEX-ICSR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f0ea579-f534-42de-b7d9-79f7260dda2b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],\n",
       "        num_rows: 9624\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],\n",
       "        num_rows: 2407\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],\n",
       "        num_rows: 3628\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07f108bc-d06e-4a93-9b69-1bc5eb09c1e2",
   "metadata": {},
   "source": [
    "### Define Dataset Processing Functions\n",
    "\n",
    "Here we define some basic functions to get the set of reactions (labels) and samples from the BioDEX dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cecfa23d-f263-44e3-bd07-7aa0ccbc2c04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import get_tokenizer\n",
    "import re\n",
    "from typing import Set, List\n",
    "\n",
    "tokenizer = get_tokenizer()\n",
    "\n",
    "\n",
    "sample_size = 5\n",
    "\n",
    "\n",
    "def get_reactions_row(raw_target: str) -> List[str]:\n",
    "    \"\"\"Get reactions from a single row.\"\"\"\n",
    "    reaction_pattern = re.compile(r\"reactions:\\s*(.*)\")\n",
    "    reaction_match = reaction_pattern.search(raw_target)\n",
    "    if reaction_match:\n",
    "        reactions = reaction_match.group(1).split(\",\")\n",
    "        reactions = [r.strip().lower() for r in reactions]\n",
    "    else:\n",
    "        reactions = []\n",
    "    return reactions\n",
    "\n",
    "\n",
    "def get_reactions_set(dataset) -> Set[str]:\n",
    "    \"\"\"Get set of all reactions.\"\"\"\n",
    "    reactions = set()\n",
    "    for data in dataset[\"train\"]:\n",
    "        reactions.update(set(get_reactions_row(data[\"target\"])))\n",
    "    return reactions\n",
    "\n",
    "\n",
    "def get_samples(dataset, sample_size: int = 5):\n",
    "    \"\"\"Get processed sample.\n",
    "\n",
    "    Contains source text and also the reaction label.\n",
    "\n",
    "    Parse reaction text to specifically extract reactions.\n",
    "\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "    for idx, data in enumerate(dataset[\"train\"]):\n",
    "        if idx >= sample_size:\n",
    "            break\n",
    "        text = data[\"fulltext_processed\"]\n",
    "        raw_target = data[\"target\"]\n",
    "\n",
    "        reactions = get_reactions_row(raw_target)\n",
    "\n",
    "        samples.append({\"text\": text, \"reactions\": reactions})\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e157ac22-74e6-4701-8dd5-45bc88468061",
   "metadata": {},
   "source": [
    "## Use LlamaPack\n",
    "\n",
    "In this first section we use our infer-retrieve-rerank LlamaPack to output predicted labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b72c15e3-6888-405d-bcfe-7e73a1e1577c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option: if developing with the llama_hub package\n",
    "from llama_index.packs.infer_retrieve_rerank import InferRetrieveRerankPack\n",
    "\n",
    "# # Option: download_llama_pack\n",
    "from llama_index.core.llama_pack import download_llama_pack\n",
    "\n",
    "InferRetrieveRerankPack = download_llama_pack(\n",
    "    \"InferRetrieveRerankPack\",\n",
    "    \"./irr_pack\",\n",
    "    # leave the below line commented out if using the notebook on main\n",
    "    # llama_hub_url=\"https://raw.githubusercontent.com/run-llama/llama-hub/jerry/add_infer_retrieve_rerank/llama_hub\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc88294-fe3f-4f69-a06f-3415e901dd36",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating embeddings: 0it [00:00, ?it/s]\n",
      "Generating embeddings: 0it [00:00, ?it/s]\n",
      "Generating embeddings: 0it [00:00, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo-16k\")\n",
    "pred_context = \"\"\"\\\n",
    "The output predictins should be a list of comma-separated adverse \\\n",
    "drug reactions. \\\n",
    "\"\"\"\n",
    "reranker_top_n = 10\n",
    "\n",
    "pack = InferRetrieveRerankPack(\n",
    "    get_reactions_set(dataset),\n",
    "    llm=llm,\n",
    "    pred_context=pred_context,\n",
    "    reranker_top_n=reranker_top_n,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ba3a80-019d-4f1c-8d87-8635ebbc6af3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Generating predictions for input 0: TITLE:\n",
      "SARS-CoV-2-related ARDS in a maintenance hemodialysis patient: case report on tailored approach by daily hemodialysis, noninvasive ventilation, tocilizumab, anxiolytics, and point-of-care ultrasound.\n",
      "\n",
      "ABSTRACT:\n",
      "Without rescue drugs approved, holistic approach by daily hemodialysis, noninvasiv\n",
      "> Generated predictions: ['respiratory distress', 'fluid overload', 'fluid retention', 'anxiety', 'delirium', 'nervousness', 'acute myocardial infarction', 'cardiovascular insufficiency', 'neonatal respiratory distress syndrome', 'delirium tremens']\n",
      "> Generating predictions for input 1: TITLE:\n",
      "Corynebacterium propinquum: A Rare Cause of Prosthetic Valve Endocarditis.\n",
      "\n",
      "ABSTRACT:\n",
      "Nondiphtheria Corynebacterium species are often dismissed as culture contaminants, but they have recently become increasingly recognized as pathologic organisms. We present the case of a 48-year-old male pat\n",
      "> Generated predictions: ['chest pain', 'dyspnoea', 'dyspnoea exertional', 'dizziness', 'vertigo', 'palpitations', 'chest discomfort', 'pyrexia', 'dengue fever', 'crepitations']\n",
      "> Generating predictions for input 2: TITLE:\n",
      "A Case of Pancytopenia with Many Possible Causes: How Do You Tell Which is the Right One?\n",
      "\n",
      "ABSTRACT:\n",
      "Systemic lupus erythematosus (SLE) often presents with cytopenia(s); however, pancytopenia is found less commonly, requiring the consideration of possible aetiologies other than the primary di\n",
      "> Generated predictions: ['agranulocytosis', 'haematotoxicity', 'bone marrow toxicity', 'infantile genetic agranulocytosis']\n",
      "> Generating predictions for input 3: TITLE:\n",
      "Hepatic Lesions with Secondary Syphilis in an HIV-Infected Patient.\n",
      "\n",
      "ABSTRACT:\n",
      "Syphilis among HIV-infected patients continues to be a public health concern, especially in men who have sex with men. The clinical manifestations of syphilis are protean; syphilitic hepatitis is an unusual complic\n",
      "> Generated predictions: ['adverse drug reaction', 'adverse drug reaction', 'idiosyncratic drug reaction', 'no adverse event']\n",
      "> Generating predictions for input 4: TITLE:\n",
      "Managing Toe Walking, a Treatment Side Effect, in a Child With T-Cell Non-Hodgkin's Lymphoma: A Case Report.\n",
      "\n",
      "ABSTRACT:\n",
      "Background and Purpose: Children who have survived cancer are at risk of experiencing adverse effects of the cancer or its treatments. One of the adverse effects may be the \n",
      "> Generated predictions: ['toe walking', 'muscular weakness', 'gait disturbance', 'bone deformity', 'joint contracture', 'joint ankylosis', 'loss of anatomical alignment after fracture reduction', 'gait inability', 'peripheral sensory neuropathy', 'myopathy']\n"
     ]
    }
   ],
   "source": [
    "samples = get_samples(dataset, sample_size=5)\n",
    "pred_reactions = pack.run(inputs=[s[\"text\"] for s in samples])\n",
    "gt_reactions = [s[\"reactions\"] for s in samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf548ddf-b877-4b32-be9e-7fe7a8b3d809",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['agranulocytosis',\n",
       " 'haematotoxicity',\n",
       " 'bone marrow toxicity',\n",
       " 'infantile genetic agranulocytosis']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_reactions[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f796c4d9-9a34-4a5c-9ccf-79e78faae1b0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['bone marrow toxicity',\n",
       " 'cytomegalovirus infection',\n",
       " 'cytomegalovirus mucocutaneous ulcer',\n",
       " 'febrile neutropenia',\n",
       " 'leukoplakia',\n",
       " 'odynophagia',\n",
       " 'oropharyngeal candidiasis',\n",
       " 'pancytopenia',\n",
       " 'product use issue',\n",
       " 'red blood cell poikilocytes present',\n",
       " 'vitamin d deficiency']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gt_reactions[2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78a6a677-8306-4b8f-b009-3a1992e9ad22",
   "metadata": {},
   "source": [
    "## Define Infer-Retrieve-Rerank Pipeline\n",
    "\n",
    "Here we define the core components needed for the full infer-retrieve-rerank pipeline. \n",
    "\n",
    "Refer to the [paper](https://arxiv.org/pdf/2401.12178.pdf) for more details. The paper implements it in DSPy, here we adapt an implementation with LlamaIndex abstractions. As a result the specific implementations (e.g. prompts, output parsing modules, reranking module) are different even though the conceptually we follow similar steps.\n",
    "\n",
    "Our implementation uses fixed models, and does not do automatic distillation between teacher and student."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2790e307-7e48-4f6b-ac30-ca261c98a11a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.retrievers import BaseRetriever\n",
    "from llama_index.core.llms import LLM\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core import PromptTemplate\n",
    "from llama_index.core.query_pipeline import QueryPipeline\n",
    "from llama_index.core.postprocessor.types import BaseNodePostprocessor\n",
    "from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank\n",
    "from llama_index.core.output_parsers import ChainableOutputParser\n",
    "from typing import List"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aca770f-8a51-4206-a11c-dc59b1b90a1c",
   "metadata": {},
   "source": [
    "#### Index each Reaction with a Vector Index\n",
    "\n",
    "Since the set of reactions is quite large, we can define a vector index over all reactions. That way we can retrieve the top k most semantically similar reactions to any prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6f7ffea-d698-4fee-8609-92326a0bb7ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['burning mouth syndrome',\n",
       " 'hepatitis e',\n",
       " 'gingivitis ulcerative',\n",
       " 'page kidney',\n",
       " 'herpes simplex pneumonia']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "all_reactions = get_reactions_set(dataset)\n",
    "random.sample(all_reactions, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c15f5ce-eb3c-4776-b7ad-47f0e08f06de",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.schema import TextNode\n",
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "from llama_index.core.ingestion import IngestionPipeline\n",
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "reaction_nodes = [TextNode(text=r) for r in all_reactions]\n",
    "pipeline = IngestionPipeline(transformations=[OpenAIEmbedding()])\n",
    "reaction_nodes = await pipeline.arun(documents=reaction_nodes)\n",
    "\n",
    "index = VectorStoreIndex(reaction_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d3314e-06e7-46aa-aa70-cd726238cffc",
   "metadata": {},
   "outputs": [],
   "source": [
    "reaction_nodes[0].embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a9a312d-8197-429e-a766-55edcbcb56fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "reaction_retriever = index.as_retriever(similarity_top_k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d5f304-4e42-491d-a81e-c3a8c47e008f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['abdominal pain', 'abdominal symptom']\n"
     ]
    }
   ],
   "source": [
    "nodes = reaction_retriever.retrieve(\"abdominal\")\n",
    "print([n.get_content() for n in nodes])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93cf6707-64f7-4e45-8998-fd5b5a767656",
   "metadata": {},
   "source": [
    "#### Define Infer Prompt\n",
    "\n",
    "We define an infer prompt that given a document and relevant task context, can generate a list of comma-separated predictions.\n",
    "\n",
    "**NOTE**: This is our own prompt and not taken from the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b94016bc-8af6-4714-b713-1662ec127c0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "infer_prompt_str = \"\"\"\\\n",
    "\n",
    "Your job is to output a list of predictions given context from a given piece of text. The text context,\n",
    "and information regarding the set of valid predictions is given below. \n",
    "\n",
    "Return the predictions as a comma-separated list of strings.\n",
    "\n",
    "Text Context:\n",
    "{doc_context}\n",
    "\n",
    "Prediction Info:\n",
    "{pred_context}\n",
    "\n",
    "Predictions: \"\"\"\n",
    "\n",
    "infer_prompt = PromptTemplate(infer_prompt_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1650e3d2-1cc7-4c05-bb6c-d0accf28b4f5",
   "metadata": {},
   "source": [
    "#### Define Output Parser\n",
    "\n",
    "We define a very simple output parser that can parse an output into a list of strings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3589b9e6-cfb5-4c60-800d-923ea042bf98",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PredsOutputParser(ChainableOutputParser):\n",
    "    \"\"\"Predictions output parser.\"\"\"\n",
    "\n",
    "    def parse(self, output: str) -> List[str]:\n",
    "        \"\"\"Parse predictions.\"\"\"\n",
    "        tokens = output.split(\",\")\n",
    "        return [t.strip() for t in tokens]\n",
    "\n",
    "\n",
    "preds_output_parser = PredsOutputParser()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b634b8dc-8780-45c6-bf21-0b1febdb5e0f",
   "metadata": {},
   "source": [
    "#### Define Rerank Prompt\n",
    "\n",
    "Here we define a rerank prompt that will reorder a batch of labels based on their relevance to the query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b29811f-356e-433d-a159-ebc711d5f5bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "rerank_str = \"\"\"\\\n",
    "Given a piece of text, rank the {num} labels above based on their relevance \\\n",
    "to this piece of text. The labels \\\n",
    "should be listed in descending order using identifiers. \\\n",
    "The most relevant labels should be listed first. \\\n",
    "The output format should be [] > [], e.g., [1] > [2]. \\\n",
    "Only response the ranking results, \\\n",
    "do not say any word or explain. \\\n",
    "\n",
    "Here is a given piece of text: {query}. \n",
    "\n",
    "\"\"\"\n",
    "rerank_prompt = PromptTemplate(rerank_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bd6f9ee-91cd-4b38-89b0-df2a72f007af",
   "metadata": {},
   "source": [
    "#### Define Infer-Retrieve-Rerank Function\n",
    "\n",
    "We define the infer-retrieve-rerank steps as a function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c9a4ce-28d8-4bc1-a4f6-335328c7c596",
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer_retrieve_rerank(\n",
    "    query: str,\n",
    "    retriever: BaseRetriever,\n",
    "    llm: LLM,\n",
    "    pred_context: str,\n",
    "    reranker_top_n: int = 3,\n",
    "):\n",
    "    \"\"\"Infer retrieve rerank.\"\"\"\n",
    "    infer_prompt_c = infer_prompt.as_query_component(\n",
    "        partial={\"pred_context\": pred_context}\n",
    "    )\n",
    "    infer_pipeline = QueryPipeline(chain=[infer_prompt_c, llm, preds_output_parser])\n",
    "    preds = infer_pipeline.run(query)\n",
    "\n",
    "    print(f\"PREDS: {preds}\")\n",
    "    all_nodes = []\n",
    "    for pred in preds:\n",
    "        nodes = retriever.retrieve(str(pred))\n",
    "        all_nodes.extend(nodes)\n",
    "\n",
    "    reranker = RankGPTRerank(\n",
    "        llm=llm,\n",
    "        top_n=reranker_top_n,\n",
    "        rankgpt_rerank_prompt=rerank_prompt,\n",
    "        # verbose=True,\n",
    "    )\n",
    "    reranked_nodes = reranker.postprocess_nodes(all_nodes, query_str=query)\n",
    "    return [n.get_content() for n in reranked_nodes]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ea0639-b929-4754-a836-6199e36b970d",
   "metadata": {},
   "source": [
    "## Run Over Sample Data\n",
    "\n",
    "Now we're ready to run over some sample data! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35269968-d98a-4e77-a320-536f146e58ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = get_samples(dataset, sample_size=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e7498f3-6adc-4200-8b12-76cd7939cc58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "PREDS: ['fluid overload', 'acute respiratory distress syndrome', 'anxiety', 'myocardial insufficiency', 'hypervolemia', 'hypovolemia', 'respiratory distress', 'allergic reaction', 'diarrhea', 'rash']\n",
      "1\n",
      "PREDS: ['fever', 'dizziness', 'dyspnea on exertion', 'intermittent chest pain', 'palpitations']\n",
      "2\n",
      "PREDS: ['azathioprine-induced myelotoxicity', 'drug-induced agranulocytosis']\n",
      "3\n",
      "PREDS: ['There is no information provided about adverse drug reactions in the given text context. Therefore', 'it is not possible to make any predictions about adverse drug reactions.']\n",
      "4\n",
      "PREDS: ['painful swelling in lymph nodes', 'weight loss', 'night sweats', 'hepatosplenomegaly', 'generalized lymphadenopathy', 'skin disorders', 'bone marrow disorders', 'blood disorders', 'misorientation of body segments', 'excessive backward pelvic tilt', 'excessive kyphosis', 's-shaped scoliosis', 'excessive pelvic obliquity', 'flat right foot contact', 'limited ankle dorsiflexion', 'toe walking', 'muscle weakness', 'limited mobility', 'misalignment of body segments', 'gait disturbances', 'peripheral neuropathy', 'myopathy', 'atrophy in gastrocnemius muscle', 'compensatory strategy', 'contractures', 'body posture correction', 'gait pattern improvement.']\n"
     ]
    }
   ],
   "source": [
    "reaction_retriever = index.as_retriever(similarity_top_k=2)\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo-16k\")\n",
    "pred_context = \"\"\"\\\n",
    "The output predictins should be a list of comma-separated adverse \\\n",
    "drug reactions. \\\n",
    "\"\"\"\n",
    "\n",
    "reranker_top_n = 10\n",
    "\n",
    "pred_reactions = []\n",
    "gt_reactions = []\n",
    "for idx, sample in enumerate(samples):\n",
    "    print(idx)\n",
    "    cur_pred_reactions = infer_retrieve_rerank(\n",
    "        sample[\"text\"],\n",
    "        reaction_retriever,\n",
    "        llm,\n",
    "        pred_context,\n",
    "        reranker_top_n=reranker_top_n,\n",
    "    )\n",
    "    cur_gt_reactions = sample[\"reactions\"]\n",
    "\n",
    "    pred_reactions.append(cur_pred_reactions)\n",
    "    gt_reactions.append(cur_gt_reactions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4717069f-19d0-43aa-afeb-e0cb17cb1d79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['agranulocytosis',\n",
       " 'haematotoxicity',\n",
       " 'bone marrow toxicity',\n",
       " 'infantile genetic agranulocytosis']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_reactions[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5223f8e2-c08b-459d-9ded-b25b90899616",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['bone marrow toxicity',\n",
       " 'cytomegalovirus infection',\n",
       " 'cytomegalovirus mucocutaneous ulcer',\n",
       " 'febrile neutropenia',\n",
       " 'leukoplakia',\n",
       " 'odynophagia',\n",
       " 'oropharyngeal candidiasis',\n",
       " 'pancytopenia',\n",
       " 'product use issue',\n",
       " 'red blood cell poikilocytes present',\n",
       " 'vitamin d deficiency']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gt_reactions[2]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_hub",
   "language": "python",
   "name": "llama_hub"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
